{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# TFPARN (Transformer-based Focal-Pairwise Attentive Ranking Network) for Anti-Spoofing: Complete Technical Documentation\n",
    "\n",
    "This notebook provides comprehensive documentation of TFPARN (Transformer-based Focal-Pairwise Attentive Ranking Network) and its application in the ASVspoof5 challenge (Audio Anti-Spoofing Detection). The system distinguishes between genuine human speech (bonafide) and AI-generated synthetic speech (spoof)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6240bcbt7mi",
   "source": "## 1. Background\n\n### 1.1 ASVspoof5 Challenge Overview\n\nThe **ASVspoof5 (Automatic Speaker Verification Spoofing and Countermeasures Challenge 5)** is a challenge focused on detecting spoofed (AI-generated) speech. Speech synthesis and voice conversion technology have advanced rapidly, so distinguishing genuine human speech from sophisticated deepfakes has become critical for security applications such as speaker verification.\n\n**Challenge Goal:** Build models that accurately classify audio samples as:\n- **Bonafide (Label=1):** Genuine human speech\n- **Spoof (Label=0):** AI-generated synthetic speech\n\n**Key Challenges:**\n- Diverse attack types (TTS and VC systems), combined with various codec and compression conditions\n- Class imbalance in the training data (more spoof than bonafide samples)\n- Need for robust generalization to unseen attack types\n- Real-world audio quality variations\n\n### 1.2 Evaluation Metrics\n\nFor Track 1 (stand-alone spoofing detection), the primary metric is minDCF (minimum Detection Cost Function). actDCF (actual DCF), CLLR (log-likelihood-ratio cost, C_llr), and EER (Equal Error Rate) are reported as complementary metrics. Lower values indicate better performance for all of them.",
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "id": "ndtv2pbq5e",
   "source": [
    "## 2. Environment Setup\n",
    "\n",
    "### 2.1 Dependencies\n",
    "\n",
    "The project requires PyTorch and audio processing libraries:\n",
    "\n",
    "```\n",
    "torch~=2.9.0  # developed with the CUDA 13.0 build (2.9.0+cu130)\n",
    "numpy~=2.1.2\n",
    "scikit-learn~=1.7.2\n",
    "tqdm~=4.67.1\n",
    "torchaudio~=2.9.0\n",
    "soundfile~=0.13.1\n",
    "scipy~=1.15.3\n",
    "```\n",
    "\n",
    "**Installation:**\n",
    "```bash\n",
    "pip install -r requirements.txt\n",
    "```\n",
    "\n",
    "### 2.2 Hardware Requirements\n",
    "\n",
    "**GPU Memory Requirements (for training):**\n",
    "- **Batch size 64:** ~8GB VRAM\n",
    "- **Batch size 96:** ~10GB VRAM  \n",
    "- **Batch size 128:** ~12GB VRAM\n",
    "\n",
    "**Recommended Configuration:**\n",
    "- GPU: NVIDIA RTX 3090 / 4090 or better\n",
    "- RAM: 32GB+ system memory\n",
    "- Storage: ~100GB for datasets\n",
    "\n",
    "**CPU-only mode** is supported but significantly slower for training."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "id": "olxxoem7mu",
   "source": [
    "## 3. Model Architecture\n",
    "\n",
    "### 3.1 Complete Pipeline Overview\n",
    "\n",
    "The model is an end-to-end Transformer classifier that operates directly on raw waveforms:\n",
    "\n",
    "```\n",
    "Raw Waveform -> Log-Mel Spectrogram -> Transformer Encoder -> Pooling -> Classification\n",
    "```\n",
    "\n",
    "**Architecture Stages:**\n",
    "\n",
    "1. **Frontend:** Log-Mel Spectrogram extraction (in-model computation)\n",
    "2. **Embedding:** Linear projection + Layer Normalization\n",
    "3. **Positional Encoding:** Sinusoidal positional embeddings\n",
    "4. **Backbone:** Multi-layer Transformer Encoder (self-attention)\n",
    "5. **Pooling:** Mean/Attention/Top-k pooling with masking (Attention pooling is used in TFPARN)\n",
    "6. **Classification Head:** 2-layer MLP -> Binary logits\n",
    "\n",
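    "As a quick shape check for the frontend, the sketch below computes how many frames (T') a 4-second clip yields. It assumes a centered STFT with `n_fft // 2` samples of padding on each side (the `center=True` default of `torch.stft` and `torchaudio.transforms.MelSpectrogram`). `num_frames` is a hypothetical helper, and the in-model frontend may pad differently.\n",
    "\n",
    "```python\n",
    "def num_frames(num_samples: int, n_fft: int, hop_length: int, center: bool = True) -> int:\n",
    "    \"\"\"Number of STFT frames, following the torch.stft framing convention.\"\"\"\n",
    "    if center:\n",
    "        # Assumed: n_fft // 2 samples of padding on each side\n",
    "        num_samples += 2 * (n_fft // 2)\n",
    "    return 1 + (num_samples - n_fft) // hop_length\n",
    "\n",
    "\n",
    "sample_rate = 16000\n",
    "num_samples = int(4.0 * sample_rate)  # 64,000 samples\n",
    "frames = num_frames(num_samples, n_fft=768, hop_length=160)\n",
    "print(frames)  # 401\n",
    "# Frontend output [B, 401, n_mels] -> embedding [B, 401, d_model], well below max_len=5000\n",
    "```\n",
    "\n",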
    "### 3.2 Model Configuration\n",
    "\n",
    "```python\n",
    "from dataclasses import dataclass\n",
    "\n",
    "@dataclass\n",
    "class SpeechClassifierArgs:\n",
    "    # Mel Spectrogram parameters\n",
    "    n_mels: int = 128          # Number of mel filterbanks\n",
    "    n_fft: int = 768           # FFT window size\n",
    "    hop_length: int = 160      # Hop length for STFT\n",
    "    sample_rate: int = 16000   # Audio sample rate\n",
    "    \n",
    "    # Transformer parameters\n",
    "    d_model: int = 256         # Model dimension\n",
    "    nhead: int = 8             # Number of attention heads\n",
    "    num_layers: int = 6        # Number of Transformer layers\n",
    "    dim_feedforward: int = 1024 # FFN hidden dimension\n",
    "    dropout: float = 0.3       # Dropout probability\n",
    "    activation: str = \"relu\"   # Activation function\n",
    "    \n",
    "    # Pooling method: \"mean\", \"attention\", \"top-k\"\n",
    "    pooling_method: str = \"attention\"\n",
    "    top_k_ratio: float = 0.3   # For top-k pooling\n",
    "```\n",
    "\n",
    "### 3.3 Architecture Code Example\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from model import create_model, SpeechClassifierArgs\n",
    "\n",
    "# Create model with default configuration\n",
    "args = SpeechClassifierArgs()\n",
    "model = create_model(args)\n",
    "\n",
    "# Or customize architecture\n",
    "args = SpeechClassifierArgs(\n",
    "    n_mels=160,\n",
    "    n_fft=1024,\n",
    "    d_model=256,\n",
    "    num_layers=6,\n",
    "    pooling_method=\"mean\"\n",
    ")\n",
    "model = create_model(args)\n",
    "\n",
    "# Move to GPU if available, otherwise CPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = model.to(device)\n",
    "\n",
    "# Count parameters\n",
    "num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(f\"Trainable parameters: {num_params:,}\")\n",
    "```\n",
    "\n",
    "### 3.4 Key Architecture Features\n",
    "\n",
    "**1. In-Model Mel Spectrogram Computation:**\n",
    "- No need for preprocessing\n",
    "- Mel filterbank registered as buffer (not trainable)\n",
    "- Consistent processing during training and inference\n",
    "\n",
    "**2. Positional Encoding:**\n",
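    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class PositionalEncoding(nn.Module):\n",
    "    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):\n",
    "        super().__init__()\n",
    "        self.dropout = nn.Dropout(p=dropout)\n",
    "        \n",
    "        # Sinusoidal positional encoding: sin on even dimensions, cos on odd dimensions\n",
    "        position = torch.arange(max_len).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n",
    "        \n",
    "        pe = torch.zeros(1, max_len, d_model)\n",
    "        pe[0, :, 0::2] = torch.sin(position * div_term)\n",
    "        pe[0, :, 1::2] = torch.cos(position * div_term)\n",
    "        \n",
    "        self.register_buffer('pe', pe)  # stored with the model, not trained\n",
    "\n",
    "    # Typical forward pass (assumes batch-first input x: [B, T', d_model])\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        x = x + self.pe[:, :x.size(1)]\n",
    "        return self.dropout(x)\n",
    "```\n",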
    "\n",
    "**3. Flexible Pooling Strategies:**\n",
    "\n",
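    "- **Mean Pooling:** Average of all valid (unmasked) frame embeddings (simplest and fastest)\n",
    "- **Attention Pooling:** Weighted aggregation with learned per-frame attention weights (the default `pooling_method`, used in TFPARN)\n",
    "- **Top-k Pooling:** Select the top-k frames by L2 norm (k derived from `top_k_ratio`) and aggregate them\n",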
    "\n",
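    "The three strategies above can be sketched as one masked pooling module. This is a hypothetical illustration, not the project's implementation: `MaskedPooling` is an assumed name, the attention scorer follows the formula given in Section 8.1 (`α_t = softmax(w^T tanh(W h_t))`, here with a bias in `W`), and the selected top-k frames are averaged, which the text does not specify. Every utterance is assumed to have at least one valid frame.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class MaskedPooling(nn.Module):\n",
    "    \"\"\"Hypothetical sketch: mean / attention / top-k pooling over valid frames.\"\"\"\n",
    "\n",
    "    def __init__(self, d_model: int, method: str = 'attention', top_k_ratio: float = 0.3):\n",
    "        super().__init__()\n",
    "        self.method = method\n",
    "        self.top_k_ratio = top_k_ratio\n",
    "        # Attention scorer e_t = w^T tanh(W h_t + b)\n",
    "        self.scorer = nn.Sequential(\n",
    "            nn.Linear(d_model, d_model), nn.Tanh(), nn.Linear(d_model, 1, bias=False)\n",
    "        )\n",
    "\n",
    "    def forward(self, h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:\n",
    "        # h: [B, T', D] frame embeddings; mask: [B, T'] bool, True = valid frame\n",
    "        if self.method == 'mean':\n",
    "            m = mask.unsqueeze(-1).to(h.dtype)\n",
    "            return (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)\n",
    "        if self.method == 'attention':\n",
    "            e = self.scorer(h).squeeze(-1)               # [B, T'] raw scores\n",
    "            e = e.masked_fill(~mask, float('-inf'))      # padded frames get zero weight\n",
    "            alpha = torch.softmax(e, dim=1)              # weights sum to 1 per utterance\n",
    "            return (alpha.unsqueeze(-1) * h).sum(dim=1)  # [B, D]\n",
    "        if self.method == 'top-k':\n",
    "            norms = h.norm(dim=-1).masked_fill(~mask, float('-inf'))\n",
    "            # k from the shortest utterance in the batch, so only valid frames are picked\n",
    "            min_valid = int(mask.sum(dim=1).min().item())\n",
    "            k = max(1, int(self.top_k_ratio * min_valid))\n",
    "            idx = norms.topk(k, dim=1).indices           # [B, k]\n",
    "            picked = torch.gather(h, 1, idx.unsqueeze(-1).expand(-1, -1, h.size(-1)))\n",
    "            return picked.mean(dim=1)                    # [B, D]\n",
    "        raise ValueError(f'unknown pooling method: {self.method}')\n",
    "\n",
    "\n",
    "pool = MaskedPooling(d_model=256)\n",
    "h = torch.randn(4, 401, 256)\n",
    "mask = torch.ones(4, 401, dtype=torch.bool)\n",
    "print(pool(h, mask).shape)  # torch.Size([4, 256])\n",
    "```\n",
    "\n",
    "Masking matters because padded frames would otherwise dilute the mean or attract attention weight.\n",
    "\n",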
    "**4. Classification Head:**\n",
    "```python\n",
    "self.classifier = nn.Sequential(\n",
    "    nn.Linear(d_model, d_model // 2),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(d_model // 2, 2)  # Binary: [spoof, bonafide]\n",
    ")\n",
    "```\n",
    "\n",
    "### 3.5 Forward Pass\n",
    "\n",
    "```python\n",
    "# Input: [B, 1, T] raw waveform (mono)\n",
    "# Output: [B, 2] logits [spoof_score, bonafide_score]\n",
    "\n",
    "waveforms = batch['waveforms'].to(device)  # [B, 1, 64000]\n",
    "logits = model(waveforms)  # [B, 2]\n",
    "\n",
    "# Get predictions\n",
    "probs = torch.softmax(logits, dim=1)\n",
    "predictions = torch.argmax(logits, dim=1)  # 0=spoof, 1=bonafide\n",
    "```"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "id": "ywfwto4a1so",
   "source": [
    "## 4. Data Processing Pipeline\n",
    "\n",
    "### 4.1 Data Loading Overview\n",
    "\n",
    "The data pipeline handles:\n",
    "- Protocol file parsing\n",
    "- Audio loading (FLAC format)\n",
    "- Fixed-length processing (crop/repeat strategy)\n",
    "- RawBoost augmentation (training only)\n",
    "- Test-Time Augmentation (TTA) for inference\n",
    "\n",
    "### 4.2 Protocol File Format\n",
    "\n",
    "ASVspoof5 protocol files contain 10 columns:\n",
    "\n",
    "```\n",
    "speaker_id  file_name  gender  codec  codec_q  codec_seed  attack_tag  attack_label  KEY  tmp\n",
    "```\n",
    "\n",
    "**Label Mapping:**\n",
    "- `bonafide` → 1 (genuine human speech)\n",
    "- `spoof` → 0 (AI-generated speech)\n",
    "\n",
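    "A minimal sketch of the parsing step, assuming whitespace-separated columns in the order shown above. `parse_protocol_lines` is a hypothetical stand-in for the project's `read_protocol()`, whose exact signature is not shown here, and the example line uses placeholder IDs rather than real protocol entries.\n",
    "\n",
    "```python\n",
    "from typing import Dict, List\n",
    "\n",
    "# Column order from the protocol header above\n",
    "COLUMNS = ['speaker_id', 'file_name', 'gender', 'codec', 'codec_q',\n",
    "           'codec_seed', 'attack_tag', 'attack_label', 'key', 'tmp']\n",
    "LABEL_MAP = {'bonafide': 1, 'spoof': 0}\n",
    "\n",
    "\n",
    "def parse_protocol_lines(lines: List[str]) -> List[Dict[str, object]]:\n",
    "    \"\"\"Parse whitespace-separated protocol lines into dicts with an added 'label' field.\"\"\"\n",
    "    entries = []\n",
    "    for line in lines:\n",
    "        fields = line.split()\n",
    "        if not fields:\n",
    "            continue  # skip blank lines\n",
    "        if len(fields) != len(COLUMNS):\n",
    "            raise ValueError(f'expected {len(COLUMNS)} columns, got {len(fields)}')\n",
    "        entry: Dict[str, object] = dict(zip(COLUMNS, fields))\n",
    "        entry['label'] = LABEL_MAP[entry['key']]  # bonafide -> 1, spoof -> 0\n",
    "        entries.append(entry)\n",
    "    return entries\n",
    "\n",
    "\n",
    "# Illustrative line (placeholder IDs, not taken from the real protocol files)\n",
    "rows = parse_protocol_lines(['SPK001 FILE0001 F - - - TAG01 A01 spoof -'])\n",
    "print(rows[0]['file_name'], rows[0]['label'])  # FILE0001 0\n",
    "```\n",
    "\n",
    "For a file on disk, pass `open(path).read().splitlines()` as `lines`.\n",
    "\n",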
    "### 4.3 Audio Processing Strategy\n",
    "\n",
    "**Fixed-Length Processing:**\n",
    "```python\n",
    "# Target duration: 4.0 seconds at 16 kHz = 64,000 samples\n",
    "duration_sec = 4.0\n",
    "sample_rate = 16000\n",
    "target_length = int(duration_sec * sample_rate)  # 64,000\n",
    "\n",
    "# If audio is longer: crop (random for train, center for val/test)\n",
    "# If audio is shorter: repeat-concatenate then crop\n",
    "```\n",
    "\n",
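    "A minimal NumPy sketch of this crop/repeat rule, assuming a 1-D waveform that has already been resampled to 16 kHz. `fix_length` is a hypothetical name. The text does not say whether the repeated signal is cropped from its start (as here) or at a random offset.\n",
    "\n",
    "```python\n",
    "from typing import Optional\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def fix_length(x: np.ndarray, target_length: int = 64000, train: bool = True,\n",
    "               rng: Optional[np.random.Generator] = None) -> np.ndarray:\n",
    "    \"\"\"Crop or repeat a 1-D waveform to exactly target_length samples.\"\"\"\n",
    "    n = len(x)\n",
    "    if n == 0:\n",
    "        raise ValueError('empty waveform')\n",
    "    if n < target_length:\n",
    "        # Shorter: repeat-concatenate the clip, then keep the first target_length samples\n",
    "        reps = -(-target_length // n)  # ceiling division\n",
    "        return np.tile(x, reps)[:target_length]\n",
    "    max_start = n - target_length\n",
    "    if train:\n",
    "        rng = rng if rng is not None else np.random.default_rng()\n",
    "        start = int(rng.integers(0, max_start + 1))  # random crop for training\n",
    "    else:\n",
    "        start = max_start // 2  # center crop for val/test\n",
    "    return x[start:start + target_length]\n",
    "\n",
    "\n",
    "clip = np.zeros(30000, dtype=np.float32)   # a 1.875-second clip\n",
    "print(fix_length(clip, train=False).shape)  # (64000,)\n",
    "```\n",
    "\n",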
    "**Normalization:**\n",
    "- Convert to mono (if stereo)\n",
    "- Normalize amplitude to \\[-1, 1\\]\n",
    "- Resample to 16kHz (if needed)\n",
    "\n",
    "### 4.4 RawBoost Data Augmentation\n",
    "\n",
    "RawBoost applies various augmentations to improve generalization:\n",
    "\n",
    "**Algorithm 1: Linear/Nonlinear Convolution**\n",
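    "```python\n",
    "import numpy as np\n",
    "from scipy import signal\n",
    "\n",
    "# x: 1-D float waveform (numpy array)\n",
    "# Generate random FIR filter\n",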
    "N_fir = np.random.randint(5, 15)\n",
    "h = np.random.randn(N_fir)\n",
    "h = h / np.sum(np.abs(h))\n",
    "\n",
    "# Apply convolution\n",
    "x_conv = signal.convolve(x, h, mode='same')\n",
    "\n",
    "# Optional nonlinear distortion\n",
    "if np.random.rand() > 0.5:\n",
    "    alpha = np.random.uniform(0.1, 0.5)\n",
    "    x_conv = np.tanh(alpha * x_conv)\n",
    "```\n",
    "\n",
    "**Algorithm 2: IIR Filtering**\n",
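    "```python\n",
    "import numpy as np\n",
    "from scipy import signal\n",
    "\n",
    "sample_rate = 16000\n",
    "\n",
    "# Randomly select filter type\n",
    "filter_type = np.random.choice(['lowpass', 'highpass', 'bandpass'])\n",
    "\n",
    "# Apply a Butterworth filter (low-pass branch shown; the other branches follow the same pattern)\n",
    "if filter_type == 'lowpass':\n",
    "    cutoff = np.random.uniform(1000, 4000)  # Hz\n",
    "    b, a = signal.butter(4, cutoff / (sample_rate / 2), btype='low')\n",
    "    x_filt = signal.lfilter(b, a, x)  # x: 1-D float waveform\n",
    "```\n",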
    "\n",
    "**Algorithm 3: Additive Noise**\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Add stationary white Gaussian noise at a random SNR (10-40 dB)\n",
    "snr_db = np.random.uniform(10, 40)\n",
    "signal_power = np.mean(x ** 2)  # x: 1-D float waveform\n",
    "noise = np.random.randn(len(x))\n",
    "x_noisy = x + noise * np.sqrt(signal_power / (10 ** (snr_db / 10)))\n",
    "```\n",
    "\n",
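    "The scaling in Algorithm 3 follows from the SNR definition: unit-variance noise multiplied by sqrt(P_signal / 10^(SNR/10)) has power P_signal / 10^(SNR/10), so 10·log10(P_signal / P_noise) = SNR. The sketch below wraps that step in a hypothetical helper, `add_noise_at_snr`, and checks the achieved SNR empirically. It uses NumPy's `Generator` API instead of the global `np.random` calls above.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def add_noise_at_snr(x: np.ndarray, snr_db: float, rng: np.random.Generator):\n",
    "    \"\"\"Add white Gaussian noise so that 10*log10(P_x / P_noise) is about snr_db.\"\"\"\n",
    "    signal_power = np.mean(x ** 2)\n",
    "    noise = rng.standard_normal(len(x))  # unit variance\n",
    "    scaled = noise * np.sqrt(signal_power / (10 ** (snr_db / 10)))\n",
    "    return x + scaled, scaled\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "t = np.arange(64000) / 16000\n",
    "x = 0.5 * np.sin(2 * np.pi * 440 * t)  # 4-second test tone\n",
    "x_noisy, added = add_noise_at_snr(x, snr_db=20.0, rng=rng)\n",
    "measured = 10 * np.log10(np.mean(x ** 2) / np.mean(added ** 2))\n",
    "print(measured)  # approximately 20 dB\n",
    "```\n",
    "\n",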
    "### 4.5 Test-Time Augmentation (TTA)\n",
    "\n",
    "TTA generates multiple crops per sample for variance reduction:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# For inference: generate 5 overlapping crops per audio.\n",
    "# The model predicts on all crops, then the logits are averaged.\n",
    "\n",
    "class TTADataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, base_dataset, num_crops=5):\n",
    "        self.base_dataset = base_dataset\n",
    "        self.num_crops = num_crops\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.base_dataset)\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        # Loading `item`, `waveform`, `label`, and `audio_path` for this index is omitted\n",
    "        # Generate multiple crops with 50% overlap\n",
    "        waveforms = self.base_dataset.generate_tta_crops(\n",
    "            waveform, num_crops=self.num_crops\n",
    "        )  # [num_crops, C, T]\n",
    "        return {\n",
    "            \"waveforms\": waveforms,  # [num_crops, C, T]\n",
    "            \"length\": waveforms.shape[-1],\n",
    "            \"label\": label,\n",
    "            \"speaker_id\": item['speaker_id'],\n",
    "            \"attack_label\": item['attack_label'],\n",
    "            \"audio_path\": str(audio_path)\n",
    "        }\n",
    "\n",
    "# During inference: fold crops into the batch dimension, then average per sample\n",
    "B, num_crops, C, T = waveforms.shape  # DataLoader batch: [B, num_crops, C, T]\n",
    "waveforms_flat = waveforms.view(B * num_crops, C, T)\n",
    "logits_crops = model(waveforms_flat)  # [B*num_crops, 2]\n",
    "logits = logits_crops.view(B, num_crops, 2).mean(dim=1)  # [B, 2]\n",
    "```\n",
    "\n",
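    "Evenly spaced start positions are one way to realize the overlapping crops. For a 12-second clip (192,000 samples), 5 crops of 4 s are spaced 32,000 samples apart, which gives exactly 50% overlap. For other lengths the overlap differs. How `generate_tta_crops` handles this is not shown, so the helpers below (`tta_crop_starts`, `tta_crops`) are hypothetical. The last lines check that the `.view(B, num_crops, 2).mean(dim=1)` averaging groups rows per sample.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def tta_crop_starts(num_samples: int, crop_length: int = 64000, num_crops: int = 5) -> list:\n",
    "    \"\"\"Evenly spaced crop start indices covering the whole clip.\"\"\"\n",
    "    if num_samples <= crop_length:\n",
    "        return [0] * num_crops  # the whole (padded) clip fits in one crop\n",
    "    max_start = num_samples - crop_length\n",
    "    if num_crops == 1:\n",
    "        return [max_start // 2]  # single center crop\n",
    "    return [round(i * max_start / (num_crops - 1)) for i in range(num_crops)]\n",
    "\n",
    "\n",
    "def tta_crops(x: np.ndarray, crop_length: int = 64000, num_crops: int = 5) -> np.ndarray:\n",
    "    \"\"\"Return an array of shape [num_crops, crop_length].\"\"\"\n",
    "    if len(x) < crop_length:\n",
    "        x = np.tile(x, -(-crop_length // len(x)))[:crop_length]  # repeat-pad short clips\n",
    "    starts = tta_crop_starts(len(x), crop_length, num_crops)\n",
    "    return np.stack([x[s:s + crop_length] for s in starts])\n",
    "\n",
    "\n",
    "# 12-second clip: crops start every 32,000 samples (50% overlap)\n",
    "crops = tta_crops(np.arange(192000, dtype=np.float32))\n",
    "print(crops.shape)  # (5, 64000)\n",
    "\n",
    "# Averaging per-crop logits: rows are grouped per sample, as with .view(B, num_crops, 2)\n",
    "B, num_crops = 2, 5\n",
    "logits_crops = np.arange(B * num_crops * 2, dtype=float).reshape(B * num_crops, 2)\n",
    "logits = logits_crops.reshape(B, num_crops, 2).mean(axis=1)  # [B, 2]\n",
    "```\n",
    "\n",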
    "### 4.6 DataLoader Creation\n",
    "\n",
    "```python\n",
    "from data_process import make_loaders, DefaultArgs\n",
    "\n",
    "# Configure data loading\n",
    "args = DefaultArgs()\n",
    "args.train_data_dir = \"path/to/train/flac/\"\n",
    "args.train_protocol_dir = \"path/to/train.tsv\"\n",
    "args.batch_size = 96\n",
    "args.num_workers = 8\n",
    "args.use_rawboost = True  # Enable RawBoost for training\n",
    "args.rawboost_prob = 0.5  # Apply to 50% of samples\n",
    "args.use_tta = True       # Enable TTA for dev/eval\n",
    "args.tta_num_crops = 5    # Number of crops per sample\n",
    "\n",
    "# Create dataloaders\n",
    "train_loader, dev_loader, eval_loader = make_loaders(args)\n",
    "\n",
    "# Batch structure\n",
    "for batch in train_loader:\n",
    "    waveforms = batch['waveforms']  # [B, 1, 64000]\n",
    "    labels = batch['labels']        # [B]\n",
    "    lengths = batch['lengths']      # [B]\n",
    "    break\n",
    "```"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "id": "drf6zgo6xv7",
   "source": [
    "## 5. Training Pipeline\n",
    "\n",
    "### 5.1 Training Configuration\n",
    "\n",
    "```python\n",
    "from dataclasses import dataclass\n",
    "\n",
    "@dataclass\n",
    "class ModelArgs:\n",
    "    # ...\n",
    "\n",
    "    # Training hyperparameters\n",
    "    max_epochs: int = 80\n",
    "    batch_size: int = 96\n",
    "    learning_rate: float = 1e-4\n",
    "    weight_decay: float = 1e-2\n",
    "    optimizer_type: str = \"adamw\"  # 'adam' or 'adamw'\n",
    "    \n",
    "    # Scheduler\n",
    "    scheduler_type: str = \"cosine\"  # 'cosine', 'step', or 'none'\n",
    "    scheduler_warmup_epochs: int = 5\n",
    "    \n",
    "    # Loss function\n",
    "    loss_type: str = \"focal\"  # 'ce' or 'focal'\n",
    "    focal_alpha: float = 0.1   # Weight for the bonafide (label 1) class; spoof gets 1 - alpha\n",
    "    focal_gamma: float = 2.0   # Focusing parameter\n",
    "    \n",
    "    # Pairwise ranking loss\n",
    "    enable_pairwise: bool = True\n",
    "    pairwise_margin: float = 1.0\n",
    "    pairwise_weight: float = 0.3\n",
    "    \n",
    "    # Early stopping\n",
    "    early_stopping_patience: int = 15\n",
    "    early_stopping_metric: str = \"eer\"  # 'eer', 'f1_macro', 'accuracy'\n",
    "    early_stopping_mode: str = \"min\"    # 'min' for eer, 'max' for f1/acc\n",
    "\n",
    "    # ...\n",
    "```\n",
    "\n",
    "### 5.2 Focal Loss for Class Imbalance\n",
    "\n",
    "Focal Loss addresses class imbalance by down-weighting easy examples:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class FocalLoss(nn.Module):\n",
    "    \"\"\"\n",
    "    FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)\n",
    "    \n",
    "    Args:\n",
    "        alpha: Per-class weighting factor, shape [num_classes]\n",
    "        gamma: Focusing parameter (default: 2.0)\n",
    "    \"\"\"\n",
    "    def __init__(self, alpha: torch.Tensor, gamma: float = 2.0):\n",
    "        super().__init__()\n",
    "        # Registered as a buffer so that criterion.to(device) moves alpha with the model\n",
    "        self.register_buffer('alpha', alpha.float())\n",
    "        self.gamma = gamma\n",
    "    \n",
    "    def forward(self, logits, labels):\n",
    "        one_hot = F.one_hot(labels, num_classes=2).float()\n",
    "        probs = F.softmax(logits, dim=1)\n",
    "        probs_t = (probs * one_hot).sum(dim=1)  # probability of the true class\n",
    "        \n",
    "        alpha_t = (self.alpha * one_hot).sum(dim=1)\n",
    "        focal_weight = alpha_t * (1 - probs_t) ** self.gamma\n",
    "        ce_loss = F.cross_entropy(logits, labels, reduction='none')  # = -log(p_t)\n",
    "        \n",
    "        return (focal_weight * ce_loss).mean()\n",
    "\n",
    "# Usage\n",
    "focal_alpha = torch.tensor([0.9, 0.1])  # [spoof_weight, bonafide_weight]\n",
    "criterion = FocalLoss(focal_alpha, gamma=2.0)  # call criterion.to(device) when training on GPU\n",
    "```\n",
    "\n",
    "**Why Focal Loss?**\n",
    "- ASVspoof5 training data is imbalanced (more spoof samples)\n",
    "- Focal Loss focuses on hard-to-classify examples\n",
    "\n",
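    "A small arithmetic sketch, independent of the project code, shows how γ shifts the emphasis. With γ = 2, an example the model already classifies with p_t = 0.9 keeps only (1 - 0.9)^2 = 1% of its cross-entropy, while a hard example with p_t = 0.3 keeps 49%. `focal_term` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def focal_term(p_t: float, gamma: float = 2.0, alpha_t: float = 1.0) -> float:\n",
    "    \"\"\"Per-example focal loss: -alpha_t * (1 - p_t)**gamma * log(p_t).\"\"\"\n",
    "    return -alpha_t * (1 - p_t) ** gamma * math.log(p_t)\n",
    "\n",
    "\n",
    "for p_t in (0.9, 0.6, 0.3):\n",
    "    ce = -math.log(p_t)          # plain cross-entropy for the true class\n",
    "    kept = focal_term(p_t) / ce  # fraction of CE that survives: (1 - p_t)**2\n",
    "    print(f'p_t={p_t}: CE={ce:.3f}, kept={kept:.2f}')\n",
    "\n",
    "# alpha = [0.9, 0.1]: at equal p_t, a bonafide example weighs 1/9 as much as a spoof one\n",
    "print(focal_term(0.6, alpha_t=0.1) / focal_term(0.6, alpha_t=0.9))\n",
    "```\n",
    "\n",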
    "### 5.3 Pairwise Ranking Loss\n",
    "\n",
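    "Acts as a trainable surrogate for ranking-based metrics (EER, minDCF), which depend only on how bonafide scores are ordered relative to spoof scores:\n",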
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class PairwiseRankingLoss(nn.Module):\n",
    "    \"\"\"\n",
    "    Encourages bonafide samples to score higher than spoof samples\n",
    "    Loss = max(0, margin - (score_bonafide - score_spoof))\n",
    "    \"\"\"\n",
    "    def __init__(self, margin: float = 1.0):\n",
    "        super().__init__()\n",
    "        self.margin = margin\n",
    "    \n",
    "    def forward(self, logits, labels):\n",
    "        scores = logits[:, 1]  # Bonafide logit used as the score\n",
    "        \n",
    "        bonafide_scores = scores[labels == 1]\n",
    "        spoof_scores = scores[labels == 0]\n",
    "        \n",
    "        # A batch without both classes has no pairs: return 0 instead of NaN\n",
    "        if bonafide_scores.numel() == 0 or spoof_scores.numel() == 0:\n",
    "            return logits.new_zeros(())\n",
    "        \n",
    "        # All bonafide-spoof pairs: [N_bonafide, N_spoof]\n",
    "        score_diff = bonafide_scores[:, None] - spoof_scores[None, :]\n",
    "        pairwise_loss = F.relu(self.margin - score_diff)\n",
    "        \n",
    "        return pairwise_loss.mean()\n",
    "\n",
    "# Combined loss (pairwise_weight = 0.3)\n",
    "total_loss = focal_loss + 0.3 * pairwise_loss\n",
    "```\n",
    "\n",
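    "A NumPy transcription of the same hinge computation is useful for checking values by hand (`pairwise_hinge` is a hypothetical name). With one bonafide score of 2.0 and spoof scores of 0.5 and 1.5, the differences are 1.5 and 0.5. Only the second pair falls inside the margin of 1.0 and contributes 0.5, so the mean over the two pairs is 0.25.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def pairwise_hinge(scores: np.ndarray, labels: np.ndarray, margin: float = 1.0) -> float:\n",
    "    \"\"\"Mean of max(0, margin - (s_bonafide - s_spoof)) over all bonafide/spoof pairs.\"\"\"\n",
    "    bona = scores[labels == 1]\n",
    "    spoof = scores[labels == 0]\n",
    "    if bona.size == 0 or spoof.size == 0:\n",
    "        return 0.0  # no pairs in this batch\n",
    "    diff = bona[:, None] - spoof[None, :]  # [N_bonafide, N_spoof]\n",
    "    return float(np.maximum(0.0, margin - diff).mean())\n",
    "\n",
    "\n",
    "scores = np.array([2.0, 0.5, 1.5])  # bonafide logits\n",
    "labels = np.array([1, 0, 0])\n",
    "print(pairwise_hinge(scores, labels))  # 0.25\n",
    "```\n",
    "\n",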
    "### 5.4 Optimizer and Scheduler\n",
    "\n",
    "```python\n",
    "import torch.optim as optim\n",
    "\n",
    "max_epochs, warmup_epochs, learning_rate = 80, 5, 1e-4\n",
    "\n",
    "# AdamW optimizer with decoupled weight decay\n",
    "optimizer = optim.AdamW(\n",
    "    model.parameters(),\n",
    "    lr=learning_rate,\n",
    "    weight_decay=1e-2\n",
    ")\n",
    "\n",
    "# Cosine annealing over the epochs that follow the warmup\n",
    "scheduler = optim.lr_scheduler.CosineAnnealingLR(\n",
    "    optimizer,\n",
    "    T_max=max_epochs - warmup_epochs,\n",
    "    eta_min=1e-6\n",
    ")\n",
    "\n",
    "# Linear warmup for the first 5 epochs, applied at the start of each epoch\n",
    "# (epoch 1 -> 2e-5, ..., epoch 5 -> 1e-4); scheduler.step() is only called afterwards\n",
    "if epoch <= warmup_epochs:\n",
    "    warmup_lr = learning_rate * epoch / warmup_epochs\n",
    "    for param_group in optimizer.param_groups:\n",
    "        param_group['lr'] = warmup_lr\n",
    "```\n",
    "\n",
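    "The resulting per-epoch learning rate can be written in closed form. This pure-Python sketch assumes the schedule described above: linear warmup over epochs 1-5, then `scheduler.step()` at the end of each later epoch as in 5.5, with nothing else modifying the LR. Under those conditions PyTorch's `CosineAnnealingLR` follows the closed-form cosine curve. `lr_for_epoch` is a hypothetical helper for inspecting the schedule.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def lr_for_epoch(epoch: int, base_lr: float = 1e-4, warmup_epochs: int = 5,\n",
    "                 max_epochs: int = 80, eta_min: float = 1e-6) -> float:\n",
    "    \"\"\"Learning rate used during a 1-based epoch under warmup + cosine annealing.\"\"\"\n",
    "    if epoch <= warmup_epochs:\n",
    "        return base_lr * epoch / warmup_epochs  # linear warmup\n",
    "    t_max = max_epochs - warmup_epochs\n",
    "    t = epoch - warmup_epochs - 1  # scheduler.step() calls made before this epoch\n",
    "    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2\n",
    "\n",
    "\n",
    "for epoch in (1, 5, 6, 40, 80):\n",
    "    print(epoch, f'{lr_for_epoch(epoch):.2e}')\n",
    "```\n",
    "\n",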
    "### 5.5 Training Loop\n",
    "\n",
    "```python\n",
    "from utils import EarlyStopping, compute_all_metrics\n",
    "\n",
    "# Early stopping\n",
    "early_stopping = EarlyStopping(patience=15, mode='min')\n",
    "\n",
    "best_eer = float('inf')\n",
    "\n",
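    "for epoch in range(1, max_epochs + 1):\n",
    "    # Linear warmup: scale the LR up over the first warmup_epochs epochs\n",
    "    if epoch <= warmup_epochs:\n",
    "        for param_group in optimizer.param_groups:\n",
    "            param_group['lr'] = learning_rate * epoch / warmup_epochs\n",
    "    \n",
    "    # Training\n",
    "    model.train()\n",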
    "    for batch in train_loader:\n",
    "        waveforms = batch['waveforms'].to(device)\n",
    "        labels = batch['labels'].to(device)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        logits = model(waveforms)\n",
    "        loss = criterion(logits, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "    # Validation with TTA\n",
    "    model.eval()\n",
    "    all_logits, all_labels = [], []\n",
    "    with torch.no_grad():\n",
    "        for batch in dev_loader:\n",
    "            waveforms = batch['waveforms'].to(device)  # [B, num_crops, C, T]\n",
    "            labels = batch['labels'].to(device)\n",
    "            \n",
    "            # TTA: reshape and average\n",
    "            B, num_crops, C, T = waveforms.shape\n",
    "            waveforms_flat = waveforms.view(B * num_crops, C, T)\n",
    "            logits_flat = model(waveforms_flat)\n",
    "            logits = logits_flat.view(B, num_crops, 2).mean(dim=1)\n",
    "            \n",
    "            all_logits.append(logits.cpu())\n",
    "            all_labels.append(labels.cpu())\n",
    "    \n",
    "    # Compute metrics\n",
    "    all_logits = torch.cat(all_logits)\n",
    "    all_labels = torch.cat(all_labels)\n",
    "    metrics = compute_all_metrics(all_logits, all_labels)\n",
    "    \n",
    "    print(f\"Epoch {epoch}: EER={metrics['eer']:.4f}, minDCF={metrics['min_dcf']:.4f}\")\n",
    "    \n",
    "    # Save best model\n",
    "    if metrics['eer'] < best_eer:\n",
    "        best_eer = metrics['eer']\n",
    "        torch.save(model.state_dict(), 'best_model.pt')\n",
    "    \n",
    "    # Early stopping\n",
    "    if early_stopping(metrics['eer']):\n",
    "        print(f\"Early stopping at epoch {epoch}\")\n",
    "        break\n",
    "    \n",
    "    # Update scheduler\n",
    "    if epoch > warmup_epochs:\n",
    "        scheduler.step()\n",
    "```"
   ],
   "metadata": {}
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## 6. Evaluation Pipeline\n",
    "\n",
    "### 6.1 Overview\n",
    "\n",
    "The evaluation pipeline computes standard metrics (EER, minDCF, actDCF) and applies Platt calibration, a logistic mapping fitted on the Dev set and then applied to the Eval scores:\n",
    "\n",
    "```python\n",
    "from utils import evaluate_model, apply_platt_calibration, compute_metrics_from_scores\n",
    "\n",
    "# Get predictions\n",
    "dev_logits, dev_labels = evaluate_model(model, dev_loader, device)\n",
    "eval_logits, eval_labels = evaluate_model(model, eval_loader, device)\n",
    "\n",
    "# Apply Platt calibration (fit on dev, apply to eval)\n",
    "calibrated_scores, _ = apply_platt_calibration(\n",
    "    dev_logits, dev_labels, eval_logits\n",
    ")\n",
    "\n",
    "# Compute metrics\n",
    "metrics = compute_metrics_from_scores(calibrated_scores, eval_labels)\n",
    "print(f\"EER: {metrics['eer']:.4f}, minDCF: {metrics['min_dcf']:.4f}\")\n",
    "```\n",
    "\n",
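    "To make the calibration step concrete, here is a self-contained sketch of Platt scaling: a one-dimensional logistic regression fitted on Dev scores and then applied unchanged to Eval scores. It assumes a single raw score per utterance (for example the bonafide-minus-spoof logit difference, which is an assumption about what `apply_platt_calibration` uses) and omits the target smoothing of Platt's original formulation. `fit_platt` and `apply_platt` are hypothetical names. On perfectly separable Dev scores this unregularized fit diverges, so a real implementation would add regularization.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def fit_platt(scores: np.ndarray, labels: np.ndarray, n_iter: int = 50) -> np.ndarray:\n",
    "    \"\"\"Fit P(bonafide | s) = sigmoid(a * s + b) by Newton's method; returns [a, b].\"\"\"\n",
    "    X = np.stack([scores, np.ones_like(scores)], axis=1)  # columns: score, bias\n",
    "    w = np.zeros(2)\n",
    "    for _ in range(n_iter):\n",
    "        p = 1.0 / (1.0 + np.exp(-(X @ w)))\n",
    "        grad = X.T @ (p - labels)                  # gradient of the negative log-likelihood\n",
    "        hess = X.T @ (X * (p * (1 - p))[:, None])  # 2 x 2 Hessian\n",
    "        w = w - np.linalg.solve(hess + 1e-9 * np.eye(2), grad)\n",
    "    return w\n",
    "\n",
    "\n",
    "def apply_platt(scores: np.ndarray, w: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"Calibrated log-odds a * s + b; apply a sigmoid to obtain probabilities.\"\"\"\n",
    "    return w[0] * scores + w[1]\n",
    "\n",
    "\n",
    "# Synthetic Dev set whose true log-odds are 2 * s - 1\n",
    "rng = np.random.default_rng(0)\n",
    "dev_scores = rng.normal(size=20000)\n",
    "dev_labels = (rng.random(20000) < 1 / (1 + np.exp(-(2 * dev_scores - 1)))).astype(float)\n",
    "w = fit_platt(dev_scores, dev_labels)\n",
    "print(w)  # approximately [2, -1]\n",
    "eval_calibrated = apply_platt(np.array([-1.0, 0.0, 1.0]), w)  # fitted on Dev, applied to Eval\n",
    "```\n",
    "\n",
    "Because the mapping is monotonic (a > 0), it leaves EER and minDCF unchanged and only affects threshold-dependent metrics such as actDCF and CLLR.\n",
    "\n",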
    "### 6.2 ASVspoof5 Track 1 Metrics\n",
    "\n",
    "**minDCF (Minimum Detection Cost Function):**\n",
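    "```python\n",
    "import numpy as np\n",
    "from sklearn.metrics import roc_curve\n",
    "\n",
    "def compute_min_dcf(scores, labels, c_miss=1.0, c_fa=10.0, p_target=0.95):\n",
    "    \"\"\"\n",
    "    Normalized minimum DCF for ASVspoof5 Track 1\n",
    "    Parameters: C_miss=1.0, C_fa=10.0, π_spf=0.05, so p_target = 1 - π_spf = 0.95 (bonafide prior)\n",
    "    \"\"\"\n",
    "    fpr, tpr, _ = roc_curve(labels, scores, pos_label=1)  # fpr: spoof accepted (P_fa)\n",
    "    fnr = 1 - tpr  # bonafide rejected (P_miss)\n",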
    "    \n",
    "    dcf = c_miss * fnr * p_target + c_fa * fpr * (1 - p_target)\n",
    "    dcf_def = min(c_miss * p_target, c_fa * (1 - p_target))\n",
    "    \n",
    "    return np.min(dcf / dcf_def)\n",
    "```\n",
    "\n",
    "**actDCF (Actual Detection Cost Function):**\n",
    "Computed at the Bayes-optimal threshold τ_bayes = -log(β), where β = C_miss·(1 - π_spf) / (C_fa·π_spf) ≈ 1.90 for ASVspoof5 Track 1:\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def compute_act_dcf(\n",
    "    scores: np.ndarray,\n",
    "    labels: np.ndarray,\n",
    "    c_miss: float = 1.0,\n",
    "    c_fa: float = 10.0,\n",
    "    p_target: float = 0.95\n",
    ") -> float:\n",
    "    \"\"\"\n",
    "    Compute actual Detection Cost Function (actDCF) at Bayes threshold for ASVspoof5 Track 1\n",
    "\n",
    "    Following ASVspoof5 specification:\n",
    "    - τ_bayes = -log(β) where β = C_miss * (1 - π_spf) / (C_fa * π_spf) ≈ 1.90\n",
    "    - actDCF = DCF'(τ_bayes) (normalized actual DCF at Bayes-optimal threshold)\n",
    "\n",
    "    Note: scores are expected to be probabilities P(bonafide|x). They are converted\n",
    "    to log-odds below, which behave as log-likelihood ratios only if the scores are\n",
    "    calibrated under equal class priors (e.g. via the prior correction step in 6.3).\n",
    "\n",
    "    Args:\n",
    "        scores: Probability scores P(bonafide|x) (higher = more likely bonafide)\n",
    "        labels: Ground truth labels (0=spoof, 1=bonafide)\n",
    "        c_miss: Cost of missing a bonafide (false negative), default=1.0\n",
    "        c_fa: Cost of false alarm on spoof (false positive), default=10.0\n",
    "        p_target: Prior probability of bonafide (1 - π_spf), default=0.95\n",
    "\n",
    "    Returns:\n",
    "        act_dcf_normalized: Normalized actual DCF at Bayes threshold\n",
    "    \"\"\"\n",
    "    # Compute β (beta factor): 1.0 * 0.95 / (10.0 * 0.05) = 1.9\n",
    "    beta = (c_miss * p_target) / (c_fa * (1 - p_target))\n",
    "\n",
    "    # Bayes-optimal threshold τ_bayes = -log(β)\n",
    "    # Scores are probabilities P(bonafide|x) in [0, 1], so convert them to log-odds first\n",
    "    eps = 1e-10\n",
    "    scores_clipped = np.clip(scores, eps, 1 - eps)\n",
    "\n",
    "    # Log posterior odds log(P(bonafide|x) / P(spoof|x)), used as the LLR\n",
    "    llr_scores = np.log(scores_clipped / (1 - scores_clipped))\n",
    "\n",
    "    # Bayes threshold in log-likelihood ratio space\n",
    "    tau_bayes = -np.log(beta)\n",
    "\n",
    "    # Make predictions at Bayes threshold\n",
    "    predictions = (llr_scores >= tau_bayes).astype(int)\n",
    "\n",
    "    # Compute confusion matrix elements\n",
    "    tp = np.sum((predictions == 1) & (labels == 1))\n",
    "    fp = np.sum((predictions == 1) & (labels == 0))\n",
    "    fn = np.sum((predictions == 0) & (labels == 1))\n",
    "    tn = np.sum((predictions == 0) & (labels == 0))\n",
    "\n",
    "    # Compute error rates\n",
    "    fnr = fn / (tp + fn + 1e-10)  # P_miss (miss rate for bonafide)\n",
    "    fpr = fp / (fp + tn + 1e-10)  # P_fa (false alarm rate for spoof)\n",
    "\n",
    "    # Compute unnormalized actual DCF\n",
    "    act_dcf = c_miss * fnr * p_target + c_fa * fpr * (1 - p_target)\n",
    "\n",
    "    # Normalize by DCF_def\n",
    "    dcf_def = min(c_miss * p_target, c_fa * (1 - p_target))\n",
    "    act_dcf_normalized = act_dcf / dcf_def\n",
    "\n",
    "    return act_dcf_normalized\n",
    "```\n",
    "\n",
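    "As a quick arithmetic check with the ASVspoof5 Track 1 cost parameters (C_miss = 1, C_fa = 10, π_spf = 0.05): β = (1 × 0.95) / (10 × 0.05) = 1.9 and τ_bayes = -log(1.9) ≈ -0.642. Dividing by DCF_def = min(0.95, 0.5) = 0.5 turns the DCF into β·P_miss + P_fa, so a system that accepts every trial (P_miss = 0, P_fa = 1) scores exactly 1.0. The snippet below is standalone arithmetic, not code from `utils.py`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "C_MISS, C_FA, PI_SPF = 1.0, 10.0, 0.05  # ASVspoof5 Track 1 cost parameters\n",
    "P_TARGET = 1 - PI_SPF                   # bonafide prior = 0.95\n",
    "\n",
    "beta = C_MISS * P_TARGET / (C_FA * PI_SPF)  # 0.95 / 0.5 = 1.9\n",
    "tau_bayes = -math.log(beta)                 # about -0.642 in LLR units\n",
    "\n",
    "\n",
    "def normalized_dcf(p_miss: float, p_fa: float) -> float:\n",
    "    \"\"\"DCF divided by the cost of the best trivial (accept-all / reject-all) system.\"\"\"\n",
    "    raw = C_MISS * P_TARGET * p_miss + C_FA * PI_SPF * p_fa\n",
    "    dcf_def = min(C_MISS * P_TARGET, C_FA * PI_SPF)  # = 0.5\n",
    "    return raw / dcf_def\n",
    "\n",
    "\n",
    "# The normalized DCF reduces to beta * P_miss + P_fa\n",
    "print(beta, tau_bayes, normalized_dcf(0.1, 0.2))\n",
    "```\n",
    "\n",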
    "### 6.3 Complete Evaluation\n",
    "\n",
    "```python\n",
    "def evaluate_with_calibration(model, dev_loader, eval_loader, device):\n",
    "    \"\"\"Evaluation with calibration and prior correction\"\"\"\n",
    "    # Get predictions\n",
    "    dev_logits, dev_labels = evaluate_model(model, dev_loader, device, use_tta=True)\n",
    "    eval_logits, eval_labels = evaluate_model(model, eval_loader, device, use_tta=True)\n",
    "    \n",
    "    # Apply calibration\n",
    "    calibrated_scores, _ = apply_platt_calibration(dev_logits, dev_labels, eval_logits)\n",
    "    corrected_scores = apply_prior_correction(dev_labels, eval_labels, calibrated_scores)\n",
    "    \n",
    "    # Compute metrics\n",
    "    metrics = compute_metrics_from_scores(corrected_scores, eval_labels)\n",
    "    metrics['cllr'] = compute_cllr(corrected_scores, eval_labels)\n",
    "    \n",
    "    return metrics\n",
    "```\n",
    "\n",
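    "For reference, a sketch of CLLR under its standard definition, C_llr = 1/2 · [mean over bonafide of log2(1 + e^(-s)) + mean over spoof of log2(1 + e^(s))], where s are natural-log LLR scores. `compute_cllr_sketch` is a hypothetical name, and `utils.compute_cllr` may differ in details. A system that always outputs s = 0 scores exactly 1 bit, so an informative, well-calibrated system should score below 1.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_cllr_sketch(llr: np.ndarray, labels: np.ndarray) -> float:\n",
    "    \"\"\"C_llr in bits; llr are natural-log likelihood ratios, labels 1=bonafide, 0=spoof.\"\"\"\n",
    "    bona = llr[labels == 1]\n",
    "    spoof = llr[labels == 0]\n",
    "    # log2(1 + e^-s) for bonafide and log2(1 + e^s) for spoof, computed stably\n",
    "    cost_bona = np.mean(np.logaddexp(0.0, -bona)) / np.log(2)\n",
    "    cost_spoof = np.mean(np.logaddexp(0.0, spoof)) / np.log(2)\n",
    "    return float(0.5 * (cost_bona + cost_spoof))\n",
    "\n",
    "\n",
    "labels = np.array([1, 1, 0, 0])\n",
    "print(compute_cllr_sketch(np.zeros(4), labels))                        # 1.0 (uninformative)\n",
    "print(compute_cllr_sketch(np.array([8.0, 6.0, -7.0, -9.0]), labels))  # close to 0\n",
    "```\n",
    "\n",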
    "For detailed implementation of calibration methods and CLLR computation, see `utils.py`."
   ],
   "id": "703421f44bac17d5"
  },
  {
   "cell_type": "markdown",
   "id": "ykxnj6a3k3a",
   "source": [
    "## 7. Project Structure\n",
    "\n",
    "```\n",
    "true_tone3/\n",
    "│\n",
    "├── model.py                          # Transformer model architecture\n",
    "│   ├── SpeechClassifierArgs          # Model configuration dataclass\n",
    "│   ├── PositionalEncoding            # Sinusoidal positional encoding\n",
    "│   ├── SpeechTransformerClassifier   # Main model class\n",
    "│   └── create_model()                # Model factory function\n",
    "│\n",
    "├── data_process.py                   # Data loading and preprocessing\n",
    "│   ├── DefaultArgs                   # Data loading configuration\n",
    "│   ├── RawBoost                      # RawBoost augmentation class\n",
    "│   ├── read_protocol()               # Parse ASVspoof5 protocol files\n",
    "│   ├── ASV5Dataset                   # PyTorch Dataset class\n",
    "│   ├── TTADataset                    # Test-Time Augmentation wrapper\n",
    "│   ├── collate_fn()                  # Batch collation function\n",
    "│   └── make_loaders()                # DataLoader creation\n",
    "│\n",
    "├── utils.py                          # Utility functions\n",
    "│   ├── set_seed()                    # Random seed fixing\n",
    "│   ├── get_device()                  # Device management\n",
    "│   ├── FocalLoss                     # Focal loss for class imbalance\n",
    "│   ├── PairwiseRankingLoss           # Pairwise ranking loss\n",
    "│   ├── CombinedLoss                  # Combined loss wrapper\n",
    "│   ├── compute_eer()                 # Equal Error Rate\n",
    "│   ├── compute_min_dcf()             # Minimum Detection Cost Function\n",
    "│   ├── compute_cllr()                # Calibrated Log-Likelihood Ratio\n",
    "│   ├── apply_platt_calibration()     # Platt calibration\n",
    "│   ├── apply_prior_correction()      # Prior correction\n",
    "│   ├── evaluate_model()              # Model evaluation\n",
    "│   ├── evaluate_with_calibration()   # Complete evaluation pipeline\n",
    "│   ├── load_model_weights()          # Model checkpoint loading\n",
    "│   ├── save_model()                  # Model checkpoint saving\n",
    "│   └── EarlyStopping                 # Early stopping handler\n",
    "│\n",
    "├── main_train.py                     # Main training script\n",
    "│   ├── ModelArgs                     # Complete training configuration\n",
    "│   ├── train_one_epoch()             # Training loop for one epoch\n",
    "│   ├── validate()                    # Validation with TTA\n",
    "│   └── main()                        # Main training pipeline\n",
    "│\n",
    "├── read_and_evaluate.py              # Model evaluation script\n",
    "│   ├── DatasetConfig                 # Dataset configuration\n",
    "│   ├── EvaluationConfig              # Evaluation configuration\n",
    "│   ├── create_dataloader()           # Create single dataloader\n",
    "│   ├── evaluate_dataset()            # Evaluate on single dataset\n",
    "│   └── main()                        # Main evaluation pipeline\n",
    "│\n",
    "├── run_multiple_experiments.py       # Hyperparameter tuning script\n",
    "│   └── Grid search over hyperparameters\n",
    "│\n",
    "├── requirements.txt                  # Python dependencies\n",
    "│\n",
    "├── README.md                         # Quick start guide\n",
    "│\n",
    "└── Introduction_of_true_tone5.ipynb  # This file\n",
    "```\n",
    "\n",
    "### 7.1 Core Modules\n",
    "\n",
    "**model.py:**\n",
    "- Complete Transformer architecture\n",
    "- In-model mel spectrogram computation\n",
    "- Flexible pooling strategies\n",
    "- ~4.8M trainable parameters\n",
    "\n",
    "**data_process.py:**\n",
    "- ASVspoof5 protocol parsing\n",
    "- Audio loading and preprocessing\n",
    "- RawBoost augmentation\n",
    "- Test-Time Augmentation support\n",
    "\n",
    "**utils.py:**\n",
    "- Comprehensive evaluation metrics\n",
    "- Focal Loss + Pairwise Ranking Loss\n",
    "- Platt calibration and prior correction\n",
    "\n",
    "### 7.2 Scripts\n",
    "\n",
    "**main_train.py:**\n",
    "- End-to-end training pipeline\n",
    "- Automatic device selection\n",
    "- Early stopping\n",
    "- Best model saving\n",
    "- Final evaluation with calibration\n",
    "\n",
    "**read_and_evaluate.py:**\n",
    "- Flexible evaluation on multiple datasets\n",
    "- Automatic calibration (uses Dev as reference)\n",
    "- Supports TTA\n",
    "- Detailed metrics reporting\n",
    "\n",
    "**run_multiple_experiments.py:**\n",
    "- Grid search over hyperparameters\n",
    "- Automated experiment tracking\n",
    "- Parallel experiment support"
   ],
   "metadata": {}
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## 8. Technical Summary\n\n### 8.1 Key Innovations\n\n**1. Focal Loss for Class Imbalance**\n- Addresses class imbalance in ASVspoof5 training data\n- Down-weights easy examples and focuses on hard-to-classify samples\n- Focal Loss: `FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)`\n- Significant improvement over standard Cross-Entropy\n- Hyperparameters: α=0.1 (bonafide weight; spoof weight 1 - α = 0.9), γ=2.0 (focusing parameter)\n\n**2. Pairwise Ranking Loss**\n- Acts as a trainable surrogate for ranking-based metrics (EER/minDCF)\n- Encourages bonafide samples to score higher than spoof samples\n- Loss: `L_pairwise = max(0, margin - (score_bonafide - score_spoof))`, averaged over all bonafide/spoof pairs in a batch\n- Combined with Focal Loss: `L_total = L_focal + λ * L_pairwise`\n- Hyperparameters: margin=1.0, λ=0.3 (pairwise weight)\n\n**3. Attention Pooling**\n- Learned attention mechanism for frame-level aggregation\n- Weights informative frames more heavily than uninformative ones\n- More effective than mean pooling or top-k pooling in this setup\n- Attention scores: `α_t = softmax(w^T * tanh(W * h_t))` (softmax taken over frames t)\n- Aggregated representation: `h = Σ_t α_t * h_t`\n\n**4. Test-Time Augmentation**\n- Generates 5 overlapping crops per sample during inference\n- Averages logits across crops for robust predictions\n- Reduces prediction variance caused by scoring only a single 4-second segment\n- ~5% improvement in EER\n\n**5. RawBoost Augmentation**\n- Three augmentation algorithms (convolution, filtering, noise)\n- Applied only during training, to 50% of samples\n- Improves generalization to unseen attacks and codec variations\n- Reduces overfitting to the training data distribution\n\n### 8.2 Architecture Highlights\n\n**Transformer Configuration:**\n```\nInput: [B, 1, 64000] → 16kHz, 4-second waveform\nFrontend: Log-Mel Spectrogram [B, T', 160]\nEmbedding: Linear projection [B, T', 256]\nPositional: Sinusoidal encoding\nBackbone: 6-layer Transformer (8 heads, dim=256)\nPooling: Mean/Attention/Top-k (Attention is used in TFPARN)\nClassifier: 2-layer MLP → [B, 2] logits\n```\n\n**Model Size:**\n- Parameters: ~4.8M trainable\n- VRAM: 8-12GB (batch size dependent)\n\n**Training Time:**\n- 1 epoch (training + validation with TTA): ~5 minutes (RTX 5090)\n- Full training: 4-8 hours (with early stopping)\n- Batch size 96: fits in ~10GB VRAM\n\n### 8.3 Advantages and Limitations\n\n**Advantages:**\n\n✓ End-to-end trainable (no offline feature extraction)\n\n✓ Strong generalization with Focal Loss\n\n✓ Calibrated scores (Platt calibration fitted on Dev)\n\n✓ Flexible pooling strategies\n\n✓ TTA for improved robustness\n\n✓ Complete evaluation pipeline\n\n**Limitations:**\n\n✗ Requires a GPU for efficient training (CPU works but is very slow)\n\n✗ Memory-intensive for large batch sizes\n\n✗ No explicit codec awareness",
   "id": "6a3328e5fc47d46b"
  },
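  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The combined objective `L_total = L_focal + λ * L_pairwise` from Section 8.1 can be sketched in PyTorch as below. This is an illustration under assumptions, not the project's training code: it assumes a 2-logit classifier with class index 1 = bonafide, applies α to bonafide and 1 - α to spoof, uses the bonafide-minus-spoof logit difference as the ranking score, and averages the hinge over every bonafide/spoof pair in the batch. The function names are hypothetical.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "def focal_loss(logits, labels, alpha=0.1, gamma=2.0):\n",
    "    # logits: [B, 2]; labels: [B] with 1 = bonafide, 0 = spoof\n",
    "    log_pt = F.log_softmax(logits, dim=-1).gather(1, labels.unsqueeze(1)).squeeze(1)\n",
    "    pt = log_pt.exp()\n",
    "    # alpha_t = alpha for bonafide, 1 - alpha for spoof (assumed convention)\n",
    "    alpha_t = torch.where(labels == 1, torch.full_like(pt, alpha), torch.full_like(pt, 1.0 - alpha))\n",
    "    return (-alpha_t * (1.0 - pt) ** gamma * log_pt).mean()\n",
    "\n",
    "\n",
    "def pairwise_ranking_loss(logits, labels, margin=1.0):\n",
    "    # Assumed score: bonafide logit minus spoof logit (higher = more bonafide)\n",
    "    scores = logits[:, 1] - logits[:, 0]\n",
    "    bona, spoof = scores[labels == 1], scores[labels == 0]\n",
    "    if bona.numel() == 0 or spoof.numel() == 0:\n",
    "        return scores.sum() * 0.0  # no bonafide/spoof pair in this batch\n",
    "    # Hinge over every (bonafide, spoof) pair: [N_bona, N_spoof]\n",
    "    diff = bona.unsqueeze(1) - spoof.unsqueeze(0)\n",
    "    return F.relu(margin - diff).mean()\n",
    "\n",
    "\n",
    "def total_loss(logits, labels, lam=0.3, margin=1.0):\n",
    "    return focal_loss(logits, labels) + lam * pairwise_ranking_loss(logits, labels, margin)\n",
    "\n",
    "\n",
    "# Toy batch: 2 bonafide, 2 spoof\n",
    "logits = torch.tensor([[0.2, 2.5], [1.5, -0.5], [0.0, 0.1], [2.0, 0.0]], requires_grad=True)\n",
    "labels = torch.tensor([1, 0, 1, 0])\n",
    "loss = total_loss(logits, labels)\n",
    "loss.backward()\n",
    "```\n",
    "\n",
    "With γ=0 the focal term reduces to an α-weighted cross-entropy, which is a quick sanity check. Averaging over all pairs is one design choice; sampling pairs or keeping only the hardest pair per batch are common alternatives."
   ],
   "id": "b7e2d94f1c3a6058"
  },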
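  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The Transformer configuration in Section 8.2, including the attention pooling formulas from Section 8.1, can be sketched as the PyTorch module below. It is a minimal illustration under stated assumptions, not the project's model: it starts from precomputed log-mel features `[B, T', 160]` (the waveform frontend is omitted to avoid a torchaudio dependency), and the feed-forward width (1024), pooling and classifier hidden sizes (128), GELU activation, and the class name `SpoofTransformer` are assumed. With these choices the trainable parameter count lands close to the ~4.8M quoted above.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class AttentionPooling(nn.Module):\n",
    "    # alpha_t = softmax over t of w^T tanh(W h_t);  h = sum_t alpha_t * h_t\n",
    "    def __init__(self, dim=256, hidden=128):\n",
    "        super().__init__()\n",
    "        self.W = nn.Linear(dim, hidden)\n",
    "        self.w = nn.Linear(hidden, 1, bias=False)\n",
    "\n",
    "    def forward(self, h):  # h: [B, T, D]\n",
    "        alpha = torch.softmax(self.w(torch.tanh(self.W(h))), dim=1)  # [B, T, 1], sums to 1 over T\n",
    "        return (alpha * h).sum(dim=1)  # [B, D]\n",
    "\n",
    "\n",
    "def sinusoidal_encoding(length, dim):\n",
    "    # Fixed sine/cosine positional encoding (Vaswani et al., 2017)\n",
    "    pos = torch.arange(length, dtype=torch.float32).unsqueeze(1)\n",
    "    div = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))\n",
    "    pe = torch.zeros(length, dim)\n",
    "    pe[:, 0::2] = torch.sin(pos * div)\n",
    "    pe[:, 1::2] = torch.cos(pos * div)\n",
    "    return pe\n",
    "\n",
    "\n",
    "class SpoofTransformer(nn.Module):  # hypothetical name\n",
    "    def __init__(self, n_mels=160, dim=256, heads=8, layers=6, ff=1024, max_len=2000):\n",
    "        super().__init__()\n",
    "        self.embed = nn.Linear(n_mels, dim)\n",
    "        self.register_buffer('pos', sinusoidal_encoding(max_len, dim), persistent=False)\n",
    "        layer = nn.TransformerEncoderLayer(dim, heads, dim_feedforward=ff, batch_first=True)\n",
    "        self.encoder = nn.TransformerEncoder(layer, num_layers=layers)\n",
    "        self.pool = AttentionPooling(dim)\n",
    "        self.classifier = nn.Sequential(nn.Linear(dim, 128), nn.GELU(), nn.Linear(128, 2))\n",
    "\n",
    "    def forward(self, feats):  # feats: [B, T', n_mels] log-mel frames\n",
    "        x = self.embed(feats) + self.pos[: feats.size(1)]  # [B, T', dim]\n",
    "        x = self.encoder(x)\n",
    "        return self.classifier(self.pool(x))  # [B, 2] logits\n",
    "\n",
    "\n",
    "model = SpoofTransformer()\n",
    "logits = model(torch.randn(2, 400, 160))  # 2 clips x 400 frames (frame count depends on the hop length)\n",
    "n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "```"
   ],
   "id": "c4a81f6e2d9b7035"
  },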
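  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Test-time augmentation as summarized in Section 8.1 can be sketched as follows. The crop layout is an assumption: the sketch takes 5 evenly spaced 64000-sample (4 s at 16 kHz) windows, which overlap whenever the clip is shorter than 5 x 4 s, zero-pads clips shorter than one window, and averages the per-crop logits. `model` is any callable mapping `[N, 1, 64000]` waveforms to `[N, 2]` logits; `tta_logits` and `dummy_model` are hypothetical helpers.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "def tta_logits(model, wave, n_crops=5, crop_len=64000):\n",
    "    # wave: 1-D tensor of raw samples for one utterance (16 kHz assumed)\n",
    "    if wave.numel() < crop_len:\n",
    "        # Zero-pad clips shorter than one window (padding strategy is an assumption)\n",
    "        wave = F.pad(wave, (0, crop_len - wave.numel()))\n",
    "    max_start = wave.numel() - crop_len\n",
    "    # Evenly spaced start offsets; crops overlap for clips shorter than n_crops * crop_len\n",
    "    starts = torch.linspace(0, max_start, n_crops).round().long().tolist()\n",
    "    crops = torch.stack([wave[s:s + crop_len] for s in starts]).unsqueeze(1)  # [n_crops, 1, crop_len]\n",
    "    with torch.no_grad():\n",
    "        logits = model(crops)  # [n_crops, 2]\n",
    "    return logits.mean(dim=0)  # logits averaged over crops, [2]\n",
    "\n",
    "\n",
    "def dummy_model(x):\n",
    "    # Stand-in for the trained network: maps [N, 1, L] waveforms to [N, 2] logits\n",
    "    m = x.mean(dim=(1, 2))\n",
    "    return torch.stack([m, -m], dim=1)\n",
    "\n",
    "\n",
    "avg = tta_logits(dummy_model, torch.randn(90000))  # ~5.6 s clip at 16 kHz\n",
    "```\n",
    "\n",
    "With a real network, call `model.eval()` first so dropout is disabled during inference."
   ],
   "id": "e9d3527a0b6f1c84"
  },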
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## 10. References\n",
    "\n",
    "### 10.1 Datasets\n",
    "\n",
    "- ASVspoof 2021 Dataset: https://www.kaggle.com/datasets/mohammedabdeldayem/avsspoof-2021\n",
    "- ASVspoof 2019 Database: https://www.kaggle.com/datasets/awsaf49/asvpoof-2019-dataset\n",
    "- ASVspoof 5: https://zenodo.org/records/14498691\n",
    "\n",
    "### 10.2 Key Papers\n",
    "\n",
    "1. **RawBoost:**\n",
    "   - \"RawBoost: A Raw Data Boosting and Augmentation Method applied to Automatic Speaker Verification Anti-Spoofing\" (ICASSP 2022)\n",
    "   - Three augmentation algorithms (implemented in this project)\n",
    "   - Paper: https://arxiv.org/abs/2111.04433\n",
    "\n",
    "2. **Focal Loss:**\n",
    "   - \"Focal Loss for Dense Object Detection\" (ICCV 2017)\n",
    "   - Addresses class imbalance by down-weighting easy examples\n",
    "   - Paper: https://arxiv.org/abs/1708.02002\n",
    "\n",
    "3. **Attention Is All You Need:**\n",
    "   - \"Attention Is All You Need\" (NeurIPS 2017)\n",
    "   - Original Transformer architecture\n",
    "   - Paper: https://arxiv.org/abs/1706.03762\n",
    "\n",
    "4. **Platt Calibration:**\n",
    "   - \"Probabilistic Outputs for Support Vector Machines and Comparisons to Regularized Likelihood Methods\" (1999)\n",
    "   - Sigmoid-based probability calibration method\n",
    "   - Paper: https://www.cs.colorado.edu/~mozer/Teaching/syllabi/6622/papers/Platt1999.pdf\n",
    "\n",
    "### 10.3 Metrics and Evaluation\n",
    "\n",
    "**EER (Equal Error Rate):**\n",
    "- Standard metric for biometric systems\n",
    "- Point where FPR = FNR\n",
    "- Lower is better\n",
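    "\n",
    "A minimal NumPy sketch of EER, assuming higher scores mean bonafide. Every observed score is tried as a threshold, and the EER is read off where FNR and FPR are closest; other implementations (for example ROC-convex-hull based ones) can give slightly different values on small sets. `compute_eer` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_eer(bona_scores, spoof_scores):\n",
    "    # Higher score = more likely bonafide; a trial is accepted if score >= threshold.\n",
    "    bona = np.sort(np.asarray(bona_scores, dtype=np.float64))\n",
    "    spoof = np.sort(np.asarray(spoof_scores, dtype=np.float64))\n",
    "    thresholds = np.concatenate([bona, spoof])\n",
    "    # searchsorted(..., side='left') counts the scores strictly below each threshold\n",
    "    fnr = np.searchsorted(bona, thresholds, side='left') / bona.size         # bonafide rejected\n",
    "    fpr = 1.0 - np.searchsorted(spoof, thresholds, side='left') / spoof.size  # spoof accepted\n",
    "    i = np.argmin(np.abs(fnr - fpr))\n",
    "    return (fnr[i] + fpr[i]) / 2.0, thresholds[i]\n",
    "\n",
    "\n",
    "eer, threshold = compute_eer([0.5, 2.0, 3.0], [0.0, 1.0, 2.5])\n",
    "```\n",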
    "\n",
    "**minDCF (Minimum Detection Cost Function):**\n",
    "- Weighted combination of miss and false-alarm rates; minDCF is its minimum over all decision thresholds\n",
    "- Standard in speaker verification\n",
    "- Formula: `DCF = C_miss * P_miss * P_target + C_fa * P_fa * (1 - P_target)`\n",
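    "\n",
    "A minimal NumPy sketch of minDCF with bonafide as the target class. It evaluates the DCF at every observed score plus one threshold that rejects everything, keeps the minimum, and divides by the cost of the better trivial system (always accept or always reject), a common normalization under which 1.0 means no better than a fixed decision. The defaults C_miss=1, C_fa=10, P_target=0.95 (a 0.05 spoof prior) are assumed to follow the ASVspoof 5 evaluation plan; check them against the official scoring code. `compute_min_dcf` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_min_dcf(bona_scores, spoof_scores, c_miss=1.0, c_fa=10.0, p_target=0.95):\n",
    "    # Target class = bonafide; a trial is accepted as bonafide if score >= threshold.\n",
    "    bona = np.sort(np.asarray(bona_scores, dtype=np.float64))\n",
    "    spoof = np.sort(np.asarray(spoof_scores, dtype=np.float64))\n",
    "    # Candidate thresholds: every observed score plus +inf (reject everything)\n",
    "    thresholds = np.concatenate([np.unique(np.concatenate([bona, spoof])), [np.inf]])\n",
    "    p_miss = np.searchsorted(bona, thresholds, side='left') / bona.size       # bonafide rejected\n",
    "    p_fa = 1.0 - np.searchsorted(spoof, thresholds, side='left') / spoof.size  # spoof accepted\n",
    "    dcf = c_miss * p_miss * p_target + c_fa * p_fa * (1.0 - p_target)\n",
    "    # Normalize by the better trivial system (accept all or reject all)\n",
    "    dcf_default = min(c_miss * p_target, c_fa * (1.0 - p_target))\n",
    "    return dcf.min() / dcf_default\n",
    "\n",
    "\n",
    "separated = compute_min_dcf([2.0, 3.0], [0.0, 1.0])\n",
    "uninformative = compute_min_dcf([0.0, 1.0], [0.0, 1.0])\n",
    "```\n",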
    "\n",
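    "**CLLR (Log-Likelihood-Ratio Cost, C_llr):**\n",
    "- Measures how good the scores are when interpreted as log-likelihood ratios, penalizing both poor discrimination and poor calibration\n",
    "- CLLR = 0: Perfect system (perfect discrimination and calibration)\n",
    "- CLLR = 1: Same cost as an uninformative system that always outputs LLR = 0; CLLR > 1 indicates poorly calibrated scores\n",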
    "- Reference: \"Application-Independent Evaluation of Speaker Detection\" (Computer Speech & Language 2006)\n",
    "- Paper: https://www.sciencedirect.com/science/article/abs/pii/S0885230805000306\n"
   ],
   "id": "156e84d21601b91e"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
